#include <linux/module.h>
// #include <linux/version.h>
#include <linux/init.h>  //_initdata
#include <linux/kallsyms.h>
#include <linux/lsm_hooks.h>
#include <linux/printk.h>
#include <linux/types.h>  // umode_t
// #include <include/asm/pgtable_types.h> // _PAGE_BIT_RW
// #include <actypes.h>  // u8

/* Capacity of ro_pages[]. */
enum { MAX_RO_PAGES = 1024 };

/*
 * Entries that lsm_test_page_ro() found without the RW bit. proc_init() sets
 * their RW bit while linking the hooks and clears it again afterwards.
 */
static struct page* ro_pages[MAX_RO_PAGES];
static unsigned int ro_pages_len; /* number of valid slots in ro_pages[] */

/*
 * Layout-only stand-in for the kernel's security_hook_heads. Hook entries are
 * initialised against it and later rebased by offset onto the real structure.
 */
struct security_hook_heads probe_dummy_security_hook_heads;

/* Build a security_hook_list initialiser binding hook HEAD to function HOOK. */
#define MY_HOOK_INIT(HEAD, HOOK)                          \
    {                                                     \
        .head = &probe_dummy_security_hook_heads.HEAD,    \
        .hook = { .HEAD = HOOK },                         \
    }

/**
 * probe_security_bprm_committed_creds - Dummy copy of
 * security_bprm_committed_creds() from security/security.c.
 *
 * @bprm: Pointer to "struct linux_binprm".
 *
 * Walks probe_dummy_security_hook_heads.bprm_committed_creds and calls every
 * hook on it, i.e. the same loop the real function runs over the kernel's
 * security_hook_heads. Nothing in this file calls it for its effect:
 * probe_security_hook_heads() hands its address to probe_find_variable(),
 * which scans this function's machine code for the address of the dummy list
 * head and decodes the real list head from the same offset in the kernel's
 * function. The body must therefore stay shaped like the kernel's version so
 * that both compile to the same layout.
 *
 * Returns nothing.
 */
void probe_security_bprm_committed_creds(struct linux_binprm* bprm)
{
    /* The do { } while (0) presumably reproduces the expansion of the
     * call_void_hook() macro used by the real function; keep it. */
    do {
        struct security_hook_list* p;
        hlist_for_each_entry(
            p, &probe_dummy_security_hook_heads.bprm_committed_creds, list)
        {
            p->hook.bprm_committed_creds(bprm);
        }

    } while (0);
}

/**
 * probe_find_symbol - Resolve a kernel symbol's address through kallsyms.
 *
 * @keyline: Symbol written as " name\n": the first character is skipped and
 *           everything from the first newline onward is ignored. At most 127
 *           characters of the name are used.
 *
 * Returns the address of the symbol on success, NULL otherwise.
 */
static void* probe_find_symbol(const char* keyline)
{
    char name[128];

    /* Drop the leading separator; snprintf always NUL-terminates. */
    snprintf(name, sizeof(name), "%s", keyline + 1);
    /* Terminate at the first newline (no-op if there is none). */
    name[strcspn(name, "\n")] = '\0';

    return (void*)kallsyms_lookup_name(name);
}

/**
 * probe_find_variable - Find a kernel variable's address using a dummy function.
 *
 * @function: Entry point of a module-local dummy that mirrors the kernel
 *            function named by @symbol.
 * @addr:     Address of the dummy variable referenced within @function.
 * @symbol:   Kernel function to decode, written as " name\n" (the form
 *            accepted by probe_find_symbol()). Anything not starting with a
 *            space is rejected.
 *
 * The first 128 bytes of @function are scanned for a reference to @addr,
 * trying three encodings in turn:
 *   1. a full, word-sized absolute address;
 *   2. a 32-bit displacement relative to the end of the displacement field;
 *   3. a sign-extended 32-bit absolute address.
 * The offset of the match is then applied to the kernel function @symbol to
 * decode the address of the corresponding kernel variable.
 *
 * This trick depends on the following assumptions:
 *
 * (1) @addr is referenced within the first 128 bytes of @function, even if
 *     additional code (e.g. debug instrumentation) is added.
 * (2) It is safe to read 128 + sizeof(unsigned long) - 1 bytes from
 *     @function, because the last probe of case 1 reads a whole word.
 * (3) The byte pattern encoding @addr occurs in @function only where @addr
 *     is actually referenced.
 * (4) The kernel's @symbol has the same code layout as @function.
 *
 * Returns a pointer to a location that holds the kernel variable's address on
 * success, NULL otherwise. In case 1 the location is inside the kernel
 * function's code. In cases 2 and 3 it is static storage inside this
 * function, so the result must be consumed before the next call; this
 * function is not reentrant.
 */
static void* probe_find_variable(void* function, unsigned long addr,
                                 const char* symbol)
{
    /* Decoded address returned by reference for cases 2 and 3. */
    static void* cp4ret;
    u8*          base = NULL; /* stays NULL unless @symbol resolves */
    u8*          cp;
    int          i;

    if (*symbol == ' ') base = probe_find_symbol(symbol);

    if (!base) return NULL;

    /* Case 1: absolute addressing. The address is embedded verbatim. */
    cp = function;
    for (i = 0; i < 128; i++) {
        if (*(unsigned long*)cp == addr) return base + i;
        cp++;
    }

    /* Case 2: PC-relative addressing. Target = end of disp32 + disp32. */
    cp = function;
    for (i = 0; i < 128; i++) {
        if ((unsigned long)(cp + sizeof(int) + *(int*)cp) == addr) {
            u8* field = base + i;

            /* Left-associative pointer arithmetic keeps a negative
             * displacement signed instead of wrapping it through size_t. */
            cp4ret = field + sizeof(int) + *(int*)field;
            return &cp4ret;
        }
        cp++;
    }

    /* Case 3: sign-extended 32-bit absolute address. */
    cp = function;
    for (i = 0; i < 128; i++) {
        if ((unsigned long)(long)(*(int*)cp) == addr) {
            cp4ret = (void*)(long)(*(int*)(base + i));
            return &cp4ret;
        }
        cp++;
    }
    return NULL;
}

/**
 * check_function_address - Verify that @ptr is exactly the entry of @symbol.
 *
 * @ptr:    Candidate address (may be NULL).
 * @symbol: Expected function name.
 *
 * @ptr is formatted with "%pS" and accepted only if the text begins with
 * "<symbol>+0x0/", i.e. @ptr lies at offset 0 of a symbol with that exact
 * name. A diagnostic is printed on failure.
 *
 * Returns @ptr on success, NULL otherwise.
 */
static void* check_function_address(void* ptr, const char* symbol)
{
    /* Static storage keeps this potentially large buffer off the kernel
     * stack, at the cost of making the function non-reentrant. */
    static char  buf[KSYM_SYMBOL_LEN];
    const size_t len = strlen(symbol);

    if (!ptr) {
        printk(KERN_EMERG "Can't resolve %s().\n", symbol);
        return NULL;
    }
    snprintf(buf, sizeof(buf), "%pS", ptr);
    /* The name must match in full (the next characters are "+0x0/"), so a
     * longer symbol sharing @symbol as a prefix is rejected. The second
     * strncmp is only reached when buf holds at least len characters. */
    if (strncmp(buf, symbol, len) || strncmp(buf + len, "+0x0/", 5)) {
        printk(KERN_EMERG "Guessed %s is %s\n", symbol, buf);
        return NULL;
    }
    return ptr;
}

/**
 * probe_security_hook_heads - Find address of "struct security_hook_heads
 * security_hook_heads".
 *
 * Returns pointer to "struct security_hook_heads" on success, NULL otherwise.
 */
struct security_hook_heads* probe_security_hook_heads(void)
{
    const unsigned int offset =
        offsetof(struct security_hook_heads, bprm_committed_creds);
    void* cp;

    struct security_hook_heads* shh;
    struct security_hook_list*  entry;
    void* cap = probe_find_symbol(" cap_bprm_set_creds\n");

    /* Get location of cap_bprm_set_creds(). */
    cap = check_function_address(cap, "cap_bprm_set_creds");
    if (!cap) return NULL;
    /* Guess "struct security_hook_heads security_hook_heads;". */
    cp = probe_find_variable(
        probe_security_bprm_committed_creds,
        ((unsigned long)&probe_dummy_security_hook_heads) + offset,
        " security_bprm_committed_creds\n");
    if (!cp) {
        printk(KERN_EMERG "Can't resolve security_bprm_committed_creds().\n");
        return NULL;
    }

    /* This should be "struct security_hook_heads security_hook_heads;". */
    shh = ((void*)(*(unsigned long*)cp)) - offset;

    hlist_for_each_entry(
        entry, &shh->bprm_set_creds,
        list) if (entry->hook.bprm_set_creds == cap) return shh;

    printk(KERN_EMERG "Guessed security_hook_heads is 0x%lx\n",
           (unsigned long)shh);
    return NULL;
}

/*
 * linx_inode_mkdir - inode_mkdir hook installed by this module.
 * Hook type: int (*inode_mkdir)(struct inode *dir, struct dentry *dentry,
 *                               umode_t mode);
 * Logs every call and returns 0 (no objection to the operation).
 */
int linx_inode_mkdir(struct inode* dir, struct dentry* dentry, umode_t mode)
{
    pr_emerg("run linx_inode_mkdir\n");
    return 0;
}

/*
 * linx_inode_rmdir - inode_rmdir hook installed by this module.
 * Hook type: int (*inode_rmdir)(struct inode *dir, struct dentry *dentry);
 * Logs every call and returns 0 (no objection to the operation).
 */
int linx_inode_rmdir(struct inode* dir, struct dentry* dentry)
{
    pr_emerg("run linx_inode_rmdir\n");
    return 0;
}

/*
 * Hooks installed by this module. Each .head starts out pointing into
 * probe_dummy_security_hook_heads and is rebased onto the kernel's real
 * security_hook_heads before the entries are linked in.
 */
static struct security_hook_list linx_hooks[] = {
    [0] = MY_HOOK_INIT(inode_mkdir, linx_inode_mkdir),
    [1] = MY_HOOK_INIT(inode_rmdir, linx_inode_rmdir),
};

/**
 * lsm_test_page_ro - Check whether the word at @addr can be written, and
 * remember its page-table entry if its RW bit is currently clear.
 *
 * @addr: Kernel virtual address to check.
 *
 * Returns 1 if the entry already has its RW bit set, or if the entry is (now
 * or previously) recorded in ro_pages[]. Returns 0 if @addr is unmapped or
 * ro_pages[] is full.
 *
 * NOTE(review): lookup_address() returns a pte_t *, not a struct page *. The
 * entry pointer is stored in ro_pages[] under the struct page * type, and its
 * RW bit is accessed through ->flags. This works only because flags is the
 * first member of struct page, so &page->flags aliases the entry word itself.
 */
static int lsm_test_page_ro(void* addr)
{
    unsigned int i;
    unsigned int level; /* out-parameter required by lookup_address(); unused */
    struct page* page;

    page = (struct page*)lookup_address((unsigned long)addr, &level);
    if (!page) return 0;
    /* RW already set: nothing needs to be toggled later. */
    if (test_bit(_PAGE_BIT_RW, &(page->flags))) return 1;
    /* Already recorded via another address covered by the same entry. */
    for (i = 0; i < ro_pages_len; i++) {
        if (page == ro_pages[i]) return 1;
    }
    /* No room to record another entry. */
    if (ro_pages_len == MAX_RO_PAGES) return 0;
    ro_pages[ro_pages_len++] = page;
    return 1;
}

static int check_ro_pages(struct security_hook_heads* hooks)
{
    int i;

    struct hlist_head* list = &hooks->capable;

    if (!probe_kernel_write(list, list, sizeof(void*))) return 1;

    for (i = 0; i < ARRAY_SIZE(linx_hooks); i++) {
        struct hlist_head*         head = linx_hooks[i].head;
        struct security_hook_list* shp;

        if (!lsm_test_page_ro(&head->first)) return 0;
        hlist_for_each_entry(
            shp, head, list) if (!lsm_test_page_ro(&shp->list.next) ||
                                 !lsm_test_page_ro(&shp->list.pprev)) return 0;
    }
    return 1;
}

/* add_hook - Append @entry at the tail of the hook list named by entry->head. */
static inline void add_hook(struct security_hook_list* entry)
{
    struct hlist_head* target = entry->head;

    hlist_add_tail_rcu(&entry->list, target);
}

static int proc_init(void)
{
    int idx;
    // struct security_hook_heads *hooks = probe_security_hook_heads();
    struct security_hook_heads* hooks =
        (struct security_hook_heads*)kallsyms_lookup_name(
            "security_hook_heads");
    if (!hooks) {
        printk(KERN_EMERG "hooks search failed\n");
    }
    for (idx = 0; idx < ARRAY_SIZE(linx_hooks); idx++)
        linx_hooks[idx].head =
            ((void*)hooks) + ((unsigned long)linx_hooks[idx].head) -
            ((unsigned long)&probe_dummy_security_hook_heads);

    if (!check_ro_pages(hooks)) {
        printk(KERN_EMERG
               "Can't update security_hook_heads due to write protected. \
				Retry with rodata=0 kernel command line option added.\n");
        return -EINVAL;
    }

    for (idx = 0; idx < ro_pages_len; idx++)
        set_bit(_PAGE_BIT_RW, &(ro_pages[idx]->flags));

    for (idx = 0; idx < ARRAY_SIZE(linx_hooks); idx++)
        add_hook(&linx_hooks[idx]);

    for (idx = 0; idx < ro_pages_len; idx++)
        clear_bit(_PAGE_BIT_RW, &(ro_pages[idx]->flags));

    printk(KERN_EMERG "hook-test install\n");
    return 0;
}

static void proc_cleanup(void)
{
    // security_delete_hooks(my_hook_list, ARRAY_SIZE(my_hook_list));
    printk(KERN_EMERG "hook-test uninstall \n");
}

/* Module metadata and entry/exit registration. */
MODULE_LICENSE("GPL");
module_init(proc_init);
module_exit(proc_cleanup);